import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time
from PIL import Image
# Per-channel CIFAR-10 normalization statistics (dataset-wide mean/std).
means = torch.Tensor([0.4914, 0.4822, 0.4465])
stds = torch.Tensor([0.2470, 0.2435, 0.2616])
# Random flip/rotation augmentation belongs on the training set only; the
# test set must be evaluated on unmodified images. (The original applied the
# augmented pipeline to both, and also shadowed the `transforms` module.)
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomRotation(10),
                                      transforms.ToTensor(),
                                      transforms.Normalize(means, stds)])
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize(means, stds)])
dl = torch.utils.data.DataLoader(datasets.CIFAR10("datasets/", download=True,
                                                  train=True, transform=train_transform),
                                 batch_size=1000,
                                 shuffle=True)
dl_test = torch.utils.data.DataLoader(datasets.CIFAR10("datasets/", download=True,
                                                       train=False, transform=test_transform),
                                      batch_size=100,
                                      shuffle=True)
# Broadcast the stats to (3, 32, 32) so whole images can be denormalized later.
stds = stds.unsqueeze(-1).unsqueeze(-1).expand(-1, 32, 32)
means = means.unsqueeze(-1).unsqueeze(-1).expand(-1, 32, 32)
class Net(nn.Module):
    """CNN classifier for 32x32 RGB images (CIFAR-10): five conv layers with
    interleaved max-pooling and batch-norm, then two fully connected layers
    producing log-probabilities over 10 classes."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)    # (3, 32, 32)  -> (32, 32, 32)
        self.conv2 = nn.Conv2d(32, 32, 3, padding=1)   # (32, 32, 32) -> (32, 32, 32), pooled to 16x16
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)   # (32, 16, 16) -> (64, 16, 16)
        self.conv4 = nn.Conv2d(64, 64, 3, padding=1)   # (64, 16, 16) -> (64, 16, 16), pooled to 8x8
        self.conv5 = nn.Conv2d(64, 128, 3, padding=1)  # (64, 8, 8)   -> (128, 8, 8), pooled to 4x4
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.logSoftmax = nn.LogSoftmax(1)
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(128 * 4 * 4, 300)
        self.fc2 = nn.Linear(300, 10)
        self.dropout = nn.Dropout(0.5)
        # Training history; fit() appends, so repeated calls accumulate.
        self.times, self.losses, self.accuracies = [], [], []

    def forward(self, x):
        # x: (batch, 3, 32, 32) -> log-probabilities of shape (batch, 10).
        x = self.batchnorm1(self.pool(self.relu(self.conv2(self.relu(self.conv1(x))))))
        x = self.batchnorm2(self.pool(self.relu(self.conv4(self.relu(self.conv3(x))))))
        x = self.pool(self.relu(self.conv5(x)))
        x = self.dropout(x.contiguous().view(-1, 128 * 4 * 4))
        x = self.dropout(self.relu(self.fc1(x)))
        return self.logSoftmax(self.fc2(x))

    def fit(self, epochs):
        """Train on the global `dl` loader for `epochs` epochs on the GPU.

        Every 30 batches: snapshot the current train loss, the number of
        correct predictions on one test batch (of 100), and elapsed time.
        Returns (losses, accuracies, times) as 1-D float tensors.
        """
        optimizer = optim.Adam(self.parameters(), lr=0.003, weight_decay=1e-5)
        count, lossFunction = 0, nn.NLLLoss()
        # Resume the time axis where a previous fit() call left off.
        lastTime, initialTime = (self.times[-1] if len(self.times) > 0 else 0), time.time()
        for epoch in range(epochs):
            for imgs, labels in dl:
                count += 1
                optimizer.zero_grad()
                imgs, labels = imgs.cuda(), labels.cuda()
                loss = lossFunction(self(imgs), labels)
                loss.backward()
                optimizer.step()
                if count % 30 == 0:
                    self.eval()
                    # no_grad: the eval snapshot must not build an autograd graph.
                    with torch.no_grad():
                        test_imgs, test_labels = next(iter(dl_test))
                        correct = (torch.argmax(self(test_imgs.cuda()), dim=1)
                                   == test_labels.cuda()).sum().item()
                    self.losses.append(loss.item())
                    # .item() stores a plain int; appending the CUDA tensor made
                    # torch.Tensor(self.accuracies) below fail and pinned GPU memory.
                    self.accuracies.append(correct)
                    self.times.append(lastTime + time.time() - initialTime)
                    self.train()
                    print(f"\rProgress: {np.round(100*count/(epochs*len(dl)))}%, loss: {self.losses[-1]}, accuracy: {self.accuracies[-1]}/100 ", end="")
        return torch.Tensor(self.losses), torch.Tensor(self.accuracies), torch.Tensor(self.times)
# Instantiate the model on the GPU and restore weights from a prior training run.
net = Net().cuda()
net.load_state_dict(torch.load("models/cnn-standard.pth"))
# Training entry point, kept disabled because the checkpoint above is loaded instead.
#losses, accuracies, times = net.fit(30); torch.save(net.state_dict(), "models/cnn-standard.pth")
# expects x of shape (#samples, 3, 32, 32)
def pickOut(self, x, convLayer=3):
    """Forward pass with an early exit: return the ReLU activations of conv
    layer `convLayer` (1-5). Any other value runs the whole network and
    returns the final log-probabilities, exactly like forward()."""
    stages = (
        lambda t: self.relu(self.conv1(t)),
        lambda t: self.relu(self.conv2(t)),
        lambda t: self.relu(self.conv3(self.batchnorm1(self.pool(t)))),
        lambda t: self.relu(self.conv4(t)),
        lambda t: self.relu(self.conv5(self.batchnorm2(self.pool(t)))),
    )
    for depth, stage in enumerate(stages, start=1):
        x = stage(x)
        if convLayer == depth:
            return x
    # No early exit requested: finish with the classifier head.
    out = self.dropout(self.pool(x).contiguous().view(-1, 128 * 4 * 4))
    out = self.dropout(self.relu(self.fc1(out)))
    return self.logSoftmax(self.fc2(out))
Net.pickOut = pickOut  # attach as a Net method (notebook-style monkey patching)
# expects image of shape (1, 3, 32, 32)
def displayConvOutputs(self, imgs):
    """Plot every channel of each conv layer's activations for `imgs`,
    one figure per layer, arranged as a 4-row grid of grayscale maps."""
    for convLayer in range(5):
        output = self.pickOut(imgs.cuda(), convLayer + 1).cpu().detach()
        print(f"Conv layer: {convLayer + 1}")
        dim = output.shape[1]  # channel count: 32/32/64/64/128, all multiples of 4
        plt.figure(num=None, figsize=(10, 4/dim*2.5*16), dpi=350)
        for i in range(dim):
            # subplot() requires integer grid dimensions; `dim/4` passed a float.
            plt.subplot(4, dim // 4, i + 1)
            plt.axis("off")
            plt.imshow(output[0][i])
        plt.show()
Net.displayConvOutputs = displayConvOutputs  # attach as a Net method
# expects image of shape (1, 3, 32, 32)
def graphPredictions(self, img, orig, means=0, stds=1):
    """Show three panels: class probabilities for `img`, the denormalized
    image itself, and the absolute difference between `img` and `orig`."""
    probabilities = torch.exp(self(img.cuda())[0]).detach().cpu()
    shown_img = (img[0].cpu() * stds + means).permute(1, 2, 0).detach()
    shown_diff = (torch.abs(img[0].cpu() - orig[0]) * stds + means).permute(1, 2, 0).detach()
    plt.figure(num=None, figsize=(10, 3), dpi=350)
    plt.subplot(1, 3, 1)
    plt.bar(categories, probabilities)
    plt.xticks(rotation='vertical')
    plt.subplot(1, 3, 2)
    plt.imshow(shown_img)
    plt.subplot(1, 3, 3)
    plt.imshow(shown_diff)
    plt.show()
Net.graphPredictions = graphPredictions  # attach as a Net method
# expects image of shape (1, 3, 32, 32)
def predict(self, img):
    """Return the predicted category name for a single image."""
    class_index = torch.argmax(self(img.cuda()), dim=1)[0]
    return categories[class_index]
Net.predict = predict  # attach as a Net method
def overStride(self, img, originalImage, convLayer, stride, optimizer):
    """Take 1000 Adam steps on `img` (which must have requires_grad=True),
    using the std of every `stride`-th channel of `convLayer`'s activations
    as the objective; plot the result and return the per-step loss history.

    NOTE(review): optimizer.step() *minimizes* this std — confirm that is
    the intended direction (activation maximization would negate the loss).
    """
    losses = []
    for i in range(1000):
        optimizer.zero_grad()
        # Sub-sample channels so coarser strides touch fewer feature maps.
        loss = self.pickOut(img, convLayer)[0][::stride].std()
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
    # Was `net.pickOut`/`net.graphPredictions` (the global instance); use
    # `self` so the method works on any Net, consistent with its signature.
    self.graphPredictions(img, originalImage, means, stds)
    return losses
Net.overStride = overStride  # attach as a Net method
def overLayer(self, convLayer):
    """Optimize one random training image against `convLayer` at
    progressively finer channel strides (32, 16, ..., 1), showing the
    predictions before and after, the full loss curve, and the resulting
    conv activations."""
    imgs, labels = next(iter(dl))
    orig = imgs[0:1]
    # Make the image itself the optimized parameter.
    img = orig.cuda().requires_grad_(True)
    print(f"Category: {categories[labels[0]]}")
    optimizer = optim.Adam([img], lr=0.003)
    losses = []
    print("Original:")
    # Was `net.graphPredictions(...)` (the global instance); use `self`.
    self.graphPredictions(img, orig, means, stds)
    for i in range(6):
        print(f"Layer: {convLayer}, stride: {int(32/2**i)}")
        losses.extend(self.overStride(img, orig, convLayer, int(32/2**i), optimizer))
    plt.figure(num=None, figsize=(10, 3), dpi=350); plt.plot(losses); plt.grid(True); plt.show()
    self.displayConvOutputs(img)
Net.overLayer = overLayer  # attach as a Net method
# CIFAR-10 class names, indexed by label id.
categories = ["plane", "auto", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
# Full test-set evaluation. With batch_size=100 over 100 test batches,
# total correct / #batches equals the accuracy in percent.
totalAccuracy = 0
net.eval()
with torch.no_grad():  # inference only; don't build autograd graphs
    for test_imgs, test_labels in dl_test:
        totalAccuracy += (torch.argmax(net(test_imgs.cuda()), dim=1) == test_labels.cuda()).sum().item()
# Was a bare expression whose value was silently discarded in a script.
print(f"Test accuracy: {totalAccuracy/len(dl_test)}%")
# Run the layer-optimization experiment four times for each conv layer 1-5.
for _layer in range(1, 6):
    for _repeat in range(4):
        net.overLayer(_layer)